function [] = summarize_inequality_analysis(wave,sample_cross,sample_panel,savedir_figures,fbase)
% SUMMARIZE_INEQUALITY_ANALYSIS Produce the figures for the inequality analysis.
%
% Compares who is treated under the original experiment and under
% previously saved EWM treatment rules (one-dimensional, quadrant, cubic):
%   (a) empirical CDFs of income among treated households, and
%   (b) zip-code-level scatter plots of a diversity index against the
%       share of households treated.
%
% inputs:
% (1) wave: wave identifier used in the names of the saved EWM result
%     files; the pooled sample is 0. NOTE: the zip-code section below
%     always uses wave 0, regardless of this argument.
% (2) sample_cross: cross-sectional table with electricity consumption,
%     covariates, and propensity scores. Columns read here: income,
%     opower, pre_avg, and the covariate columns income, uni_size,
%     vintage, min_baseline, max_baseline, std_baseline.
% (3) sample_panel: panel table; only ky_ba and ad_serv_zip are read, to
%     attach a zip code to each row of sample_cross.
% (4) savedir_figures: directory to save the figures
% (5) fbase: prefix of the saved EWM result files; each file is expected
%     at [fbase '_' <coef file name>]
%
% outputs:
% none; PNG figures are written to savedir_figures.
%
% Files read (relative to the current folder):
%   ./input/acs2019_race/acs2019_5yr_B02001_86000US02895
%   ./input/acs2019_race/zip_city
%
% EWM result files are loaded into structs (not directly into the
% function workspace) so that a saved variable can never overwrite a
% local variable or input, and so that a missing variable raises an error
% instead of silently reusing a value left over from a previous file.

%% Define colors
color_full='#FF00FF';
color_RCT='b';
color_univariate='#F8766D';
color_quadrant='#00BA38';
color_cubic='#619CFF';

%% Empirical CDF of income for treated households

% Quadrant and cubic rules ------------------------------------------------
covariates={'income','size','min'};
s_cost='PMC';
s_winter='';

% Comparison groups that do not depend on the EWM rule
income_all=sample_cross.income;
income_RCT=sample_cross.income(sample_cross.opower==1);

for i=1:size(covariates,2)
    covariate=covariates{i};

    quadrant_filename_coefs = sprintf('coef_quadrant_%s_%s%s_wave%1.0f.mat',covariate,s_cost,s_winter,wave);
    cubic_filename_coefs = sprintf('coef_cubic_%s_%s%s_wave%1.0f.mat',covariate,s_cost,s_winter,wave);

    % quadrant rule: assignment is stored for the households selected by Ind
    res_quadrant = load(sprintf('%s_%s',fbase,quadrant_filename_coefs));
    D_hh_quadrant = res_quadrant.in_Ghat_hh;
    income_hh_quadrant = sample_cross.income(res_quadrant.Ind);

    % cubic rule: assignment is used against every row of sample_cross
    res_cubic = load(sprintf('%s_%s',fbase,cubic_filename_coefs));
    D_hh_cubic = res_cubic.in_Ghat;
    income_hh_cubic = sample_cross.income;

    % income of households treated under each EWM rule
    income_quadrant=income_hh_quadrant(D_hh_quadrant==1);
    income_cubic=income_hh_cubic(D_hh_cubic==1);

    dens_income=figure; hold on;
    h_all=cdfplot(income_all);
    h_RCT=cdfplot(income_RCT);
    h_quadrant=cdfplot(income_quadrant);
    h_cubic=cdfplot(income_cubic);
    legend({'Entire pooled sample','Original experiment','Quadrant rule','Cubic rule'},...
           'Location','northwest');
    set(h_all,'Color',color_full,'LineWidth',2)
    set(h_RCT,'LineStyle','--','Color',color_RCT,'LineWidth',2)
    set(h_quadrant,'Color',color_quadrant,'LineWidth',2)
    set(h_cubic,'Color',color_cubic,'LineWidth',2)
    xlabel('Income [$k]');
    ylabel('Cumulative share of treated households');
    comb_figure_name=sprintf('inequality_cdf_%s_%s_comb.png',covariate,s_cost);
    saveas(dens_income,fullfile(savedir_figures,comb_figure_name))
end

% Univariate --------------------------------------------------------------
onedim_filename_coefs = sprintf('coef_onedim_income_baseline_%s%s_wave%1.0f.mat',s_cost,s_winter,wave);
res_onedim = load(sprintf('%s_%s',fbase,onedim_filename_coefs));

if isfield(res_onedim,'in_Ghat_hh')
    D_hh_onedim = res_onedim.in_Ghat_hh;
else
    % NOTE(review): if the one-dimensional file has no household-level
    % assignment, rebuild it from the unique-household assignment exactly
    % as the zip-code section below does. Previously this case silently
    % reused in_Ghat_hh from the last quadrant file loaded above.
    D_hh_onedim = local_expand_assignment(res_onedim.in_Ghat,res_onedim.xx1,...
                                          res_onedim.boo1,res_onedim.n);
end
income_hh_onedim=sample_cross.income(res_onedim.Ind);
income_onedim=income_hh_onedim(D_hh_onedim==1);

dens_income=figure; hold on;
h_all=cdfplot(income_all);
h_RCT=cdfplot(income_RCT);
h_onedim=cdfplot(income_onedim);
legend({'Entire pooled sample','Original experiment','One-dimensional rule'},...
       'Location','northwest');
set(h_all,'Color',color_full,'LineWidth',2)
set(h_RCT,'LineStyle','--','Color',color_RCT,'LineWidth',2)
set(h_onedim,'Color',color_univariate,'LineWidth',2)
xlabel('Income [$k]');
ylabel('Cumulative share of treated households');
comb_figure_name=sprintf('inequality_cdf_onedim_%s_comb.png',s_cost);
saveas(dens_income,fullfile(savedir_figures,comb_figure_name))

%% Racial diversity index v. treatment share @ zip code level

% add zip code to cross-sectional sample -----------------
account_zip=unique(sample_panel(:,{'ky_ba','ad_serv_zip'}),'rows');
clear sample_panel
% join on the variables the two tables share (presumably ky_ba only)
sample_cross_zip=join(sample_cross,account_zip);

% import and clean race data ------------------------
% fullfile keeps the relative path valid on every platform
race_dir=fullfile('.','input','acs2019_race');
race_data=readtable(fullfile(race_dir,'acs2019_5yr_B02001_86000US02895'));
race_data=race_data(2:end,:); % exclude the state total row
race_data.Properties.VariableNames{'name'}='zip';
race_type={'white','black','native','asian','hawaiian','other','two_plus'};
pos=5:2:18; % every other column from the 5th on holds one race count
for i=1:size(race_type,2)
    race_data.Properties.VariableNames{pos(i)}=race_type{i};
end

keep_var=['zip',race_type];
race_data_cleaned=race_data(:,keep_var); % cleaned table: row is zip code, column is race type

% calculate percentage of each race type ----------------------
race_counts=table2array(race_data_cleaned(:,2:end));
race_total=sum(race_counts,2);
race_comp=race_counts./race_total;
assert(all(abs(sum(race_comp,2)-1)<0.001)) % confirm race type add up to 1

% diversity index by zip code (columns ZipCode and diversity_idx are read)
zip_city=readtable(fullfile(race_dir,'zip_city'));
diversity_idx_zip = zip_city;

% figure for diversity index by zip code ENTIRE STATE ----------------------
index_all=figure;
histogram(diversity_idx_zip.diversity_idx,20)
xlim([0 1]); ylim([0 20]);
xlabel('Diversity Index'); ylabel('Frequency')
figure_name='Histogram_diversity_index_all_zip.png';
saveas(index_all,fullfile(savedir_figures,figure_name))

% Share treated under RCT by zip code -----------------------------------
share_fun = @(x) mean(x,'omitnan');
treat_share_zip = varfun(share_fun,sample_cross_zip,'GroupingVariables','ad_serv_zip',...
                         'InputVariables',{'opower'});
treat_share_zip.Properties.VariableNames{'Fun_opower'}='share_treated_RCT';

% Share treated under EWM rules by zip ------------------------------------
% cost{i}/SMC{i} jointly select the objective: kwh, PMC, SMC
cost={false true true};
SMC={0 0 1};
rules={'univariate','quadrant','cubic'};
wave=0; % this section always uses the pooled sample, overriding the input
covariates={'income','size','vintage','min','max','std'};
s_winter='';
ewm_output=cell(size(rules,2),size(cost,2),size(covariates,2));

for j=1:size(rules,2)
    rule = rules{j};
    for c=1:size(covariates,2)
        covariate=covariates{c};

        % name of the sample_cross column that holds this covariate
        switch covariate
            case 'size'
                covariate_check='uni_size';
            case 'min'
                covariate_check='min_baseline';
            case 'max'
                covariate_check='max_baseline';
            case 'std'
                covariate_check='std_baseline';
            otherwise
                covariate_check=covariate;
        end

        for i=1:size(cost,2)
            s_cost=local_cost_label(cost{i},SMC{i});

            switch rule
                case {'quadrant','cubic'}
                    filename_coefs = sprintf('coef_%s_%s_%s%s_wave%1.0f.mat',rule,covariate,s_cost,s_winter,wave);
                case 'univariate'
                    % the one-dimensional file does not depend on the
                    % covariate, so the same file is reused for every c
                    filename_coefs = sprintf('coef_onedim_income_baseline_%s%s_wave%1.0f.mat',s_cost,s_winter,wave);
            end
            filename = sprintf('%s_%s',fbase,filename_coefs);

            % load previously saved EWM results
            res = load(filename);

            % Household-level treatment assignment and matching zip codes,
            % plus a diagnostic scatter of the rule for visual checking
            switch rule
                case 'quadrant'
                    D_hh=res.in_Ghat_hh;
                    covariate_quadrant=sample_cross.(covariate_check);
                    quadrant_rule_check=figure;
                    gscatter(covariate_quadrant(res.Ind), sample_cross.pre_avg(res.Ind), D_hh, 'br', '..');
                    zip_r=sample_cross_zip.ad_serv_zip(res.Ind);
                    figure_name=sprintf('race_quadrant_rule_check_%s_%s.png',s_cost,covariate_check);
                    saveas(quadrant_rule_check,fullfile(savedir_figures,figure_name))
                case 'cubic'
                    D_hh=res.in_Ghat;
                    covariate_cubic=sample_cross.(covariate_check);
                    cubic_rule_check=figure;
                    gscatter(covariate_cubic, sample_cross.pre_avg, D_hh, 'br', '..');
                    zip_r=sample_cross_zip.ad_serv_zip;
                    figure_name=sprintf('race_cubic_rule_check_%s_%s.png',s_cost,covariate_check);
                    saveas(cubic_rule_check,fullfile(savedir_figures,figure_name))
                case 'univariate'
                    % get treatment assignment at the household level,
                    % instead of unique household level
                    D_hh=local_expand_assignment(res.in_Ghat,res.xx1,res.boo1,res.n);
                    onedim_rule_check=figure;
                    gscatter(ones(res.n,1),sample_cross.pre_avg(res.Ind,:), D_hh, 'kr');
                    zip_r=sample_cross_zip.ad_serv_zip(res.Ind);
                    figure_name=sprintf('race_onedim_rule_check_%s.png',s_cost);
                    saveas(onedim_rule_check,fullfile(savedir_figures,figure_name))
            end
            ewm_output{j,i,c}=table(zip_r,D_hh);
        end
    end
end

% Collapse every EWM assignment to the share treated per zip code. Linear
% indexing keeps each result in the same (rule, cost, covariate) cell.
treat_share_zip_ewm = cell(size(rules,2),size(cost,2),size(covariates,2));
for k=1:numel(ewm_output)
    share_k = varfun(share_fun,ewm_output{k},'GroupingVariables','zip_r',...
                     'InputVariables',{'D_hh'});
    share_k.Properties.VariableNames{'Fun_D_hh'}='share_treated_ewm';
    share_k.Properties.VariableNames{'zip_r'}='ad_serv_zip';
    treat_share_zip_ewm{k}=share_k;
end

% RCT treat share and ewm treat share bin scatter
common_zip=intersect(diversity_idx_zip.ZipCode,treat_share_zip_ewm{1,1,1}.ad_serv_zip); % find zipcodes that are common in the state and in the opower waves - 69/81 zip codes
% NOTE(review): this mask is built from the rows of treat_share_zip but is
% applied to race_total, whose rows follow the ACS race table. Unless both
% tables list the same zip codes in the same order, the marker sizes do
% not belong to the plotted zip codes. Left unchanged because the zip
% column types are not visible here - confirm and align on zip code.
zip_size=race_total(ismember(treat_share_zip.ad_serv_zip,common_zip,'rows'),:);

i_cost=2; % PMC
s_cost=local_cost_label(cost{i_cost},SMC{i_cost});

% Diversity index and RCT treatment share do not depend on the EWM rule
set1=diversity_idx_zip.diversity_idx(ismember(diversity_idx_zip.ZipCode,common_zip,'rows'),:);
set2=treat_share_zip.share_treated_RCT(ismember(treat_share_zip.ad_serv_zip,common_zip,'rows'),:);

for c=1:size(covariates,2)
    share_quadrant=treat_share_zip_ewm{2,i_cost,c};
    share_cubic=treat_share_zip_ewm{3,i_cost,c};
    set3_quadrant=share_quadrant.share_treated_ewm(ismember(share_quadrant.ad_serv_zip,common_zip,'rows'),:); % share treated under the quadrant rule
    set3_cubic=share_cubic.share_treated_ewm(ismember(share_cubic.ad_serv_zip,common_zip,'rows'),:); % share treated under the cubic rule

    comp_scatter=figure;
    scatter(set2,set1,zip_size/200,'MarkerEdgeColor',color_RCT,'MarkerFaceColor',color_RCT,'LineWidth',1.5,'MarkerFaceAlpha',0.7); % RCT
    hold on;
    scatter(set3_quadrant,set1,zip_size/200,'MarkerEdgeColor',color_quadrant,'MarkerFaceColor',color_quadrant,'LineWidth',1.5,'MarkerFaceAlpha',0.5); % EWM
    scatter(set3_cubic,set1,zip_size/200,'MarkerEdgeColor',color_cubic,'MarkerFaceColor',color_cubic,'LineWidth',1.5,'MarkerFaceAlpha',0.5); % EWM
    ylabel('Diversity index'); xlabel('Share treated');
    xlim([0 1]); ylim([0 1])
    legend({'Original experiment','Quadrant rule', 'Cubic rule'})
    figure_name=sprintf('inequality_race_%s_%s.png',s_cost,covariates{c});
    saveas(comp_scatter,fullfile(savedir_figures,figure_name))
end

% One-dimensional rule: its result is the same for every covariate index
% (the same file is loaded each time), so the first one is used explicitly
% rather than whatever value the loop variable above was left with.
share_onedim=treat_share_zip_ewm{1,i_cost,1};
set3_onedim=share_onedim.share_treated_ewm(ismember(share_onedim.ad_serv_zip,common_zip,'rows'),:); % share treated under the one-dimensional rule

comp_scatter=figure;
scatter(set2,set1,zip_size/200,'MarkerEdgeColor',color_RCT,'MarkerFaceColor',color_RCT,'LineWidth',1.5,'MarkerFaceAlpha',0.7); % RCT
hold on;
scatter(set3_onedim,set1,zip_size/200,'MarkerEdgeColor',color_univariate,'MarkerFaceColor',color_univariate,'LineWidth',1.5,'MarkerFaceAlpha',0.5); % EWM
ylabel('Diversity index'); xlabel('Share treated');
xlim([0 1]); ylim([0 1])
legend({'Original experiment','One-dimensional rule'})
figure_name=sprintf('inequality_race_%s_%s.png',s_cost,rules{1});
saveas(comp_scatter,fullfile(savedir_figures,figure_name))


function s_cost = local_cost_label(is_cost,is_SMC)
% Return the objective label used in file names: 'kwh', 'PMC' or 'SMC'.
if (is_cost)
    if (is_SMC)
        s_cost = 'SMC';
    else
        s_cost = 'PMC';
    end
else
    s_cost='kwh';
end


function D_hh = local_expand_assignment(in_Ghat,xx1,boo1,n)
% Expand a treatment assignment stored per unique household to n households.
% xx1 gives, for each household, the index of its unique-household group;
% households that match no group stay NaN.
D_hh = nan(n,1);
for p=1:length(boo1)
    D_hh(logical(xx1==p),1) = in_Ghat(p);
end